{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "id": "o4vBVdNx2EPr"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Starting virtual X frame buffer: Xvfb../xvfb: line 24: start-stop-daemon: command not found\n",
            ".\n"
          ]
        }
      ],
      "source": [
        "import sys, os\n",
        "if 'google.colab' in sys.modules and not os.path.exists('.setup_complete'):\n",
        "    # Install xvfb and our launcher script for it\n",
        "    !apt-get install -y xvfb\n",
        "    !wget -q https://raw.githubusercontent.com/yandexdataschool/Practical_RL/master/xvfb -O ../xvfb\n",
        "\n",
        "    # Download dependencies from Github\n",
        "    !wget https://raw.githubusercontent.com/yandexdataschool/Practical_RL/master/week06_policy_based/atari_wrappers.py\n",
        "    !wget https://raw.githubusercontent.com/yandexdataschool/Practical_RL/master/week06_policy_based/env_batch.py\n",
        "    !wget https://raw.githubusercontent.com/yandexdataschool/Practical_RL/master/week06_policy_based/runners.py\n",
        "\n",
        "    # Update the gym environment to be compatible with the Atari environment\n",
        "    !pip install -q gymnasium[atari,accept-rom-license]\n",
        "    !pip install -q tensorboardX\n",
        "\n",
        "    !touch .setup_complete\n",
        "\n",
        "# This code creates a virtual display to draw game images on.\n",
        "# It will have no effect if your machine has a monitor.\n",
        "if type(os.environ.get(\"DISPLAY\")) is not str or len(os.environ.get(\"DISPLAY\")) == 0:\n",
        "    !bash ../xvfb start\n",
        "    os.environ['DISPLAY'] = ':1'"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "O_iJbFWQ2EPs"
      },
      "source": [
        "# Implementing Advantage-Actor Critic (A2C)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "16ownLDJ2EPs"
      },
      "source": [
        "In this notebook you will implement Advantage Actor Critic algorithm that trains on a batch of Atari 2600 environments running in parallel.\n",
        "\n",
        "Firstly, we will use environment wrappers implemented in file `atari_wrappers.py`. These wrappers preprocess observations (resize, grayscale, take max between frames, skip frames and stack them together) and rewards. Some of the wrappers help to reset the environment and pass `done` flag equal to `True` when agent dies.\n",
        "File `env_batch.py` includes implementation of `ParallelEnvBatch` class that allows to run multiple environments in parallel. To create an environment we can use `nature_dqn_env` function. Note that if you are using\n",
        "PyTorch and not using `tensorboardX` you will need to implement a wrapper that will log **raw** total rewards that the *unwrapped* environment returns and redefine the implemention of `nature_dqn_env` function here.\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "uScP-zu12EPt"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "import gymnasium as gym\n",
        "from atari_wrappers import nature_dqn_env\n",
        "\n",
        "\n",
        "env_name = \"SpaceInvadersNoFrameskip-v4\"\n",
        "nenvs = 8  # change this if you have more than 8 CPU ;)\n",
        "summaries = \"Tensorboard\"\n",
        "\n",
        "env = nature_dqn_env(env_name, nenvs=nenvs, summaries=summaries)\n",
        "obs, _ = env.reset()\n",
        "assert obs.shape == (nenvs, 4, 84, 84)\n",
        "assert obs.dtype == np.float32\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jiWeYgmd2EPt"
      },
      "source": [
        "Next, we will need to implement a model that predicts logits and values. It is suggested that you use the same model as in [Nature DQN paper](https://www.nature.com/articles/nature14236) with a modification that instead of having a single output layer, it will have two output layers taking as input the output of the last hidden layer. **Note** that this model is different from the model you used in homework where you implemented DQN. You can use your favorite deep learning framework here. We suggest that you use orthogonal initialization with parameter $\\sqrt{2}$ for kernels and initialize biases with zeros."
      ]
    },
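    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the suggested initialization concrete, the following is a minimal NumPy sketch of orthogonal initialization with gain $\\sqrt{2}$. The helper name `orthogonal_init` is hypothetical and only illustrates the idea; with a deep learning framework you would normally rely on its built-in initializer (for example, `torch.nn.init.orthogonal_` in PyTorch).\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "\n",
        "def orthogonal_init(rows, cols, gain=np.sqrt(2), seed=0):\n",
        "    # Hypothetical helper: returns a (rows, cols) matrix whose rows\n",
        "    # (if rows < cols) or columns (otherwise) are orthonormal, scaled by gain.\n",
        "    rng = np.random.default_rng(seed)\n",
        "    a = rng.standard_normal((max(rows, cols), min(rows, cols)))\n",
        "    q, r = np.linalg.qr(a)       # q has orthonormal columns\n",
        "    q = q * np.sign(np.diag(r))  # remove the sign ambiguity of the QR factors\n",
        "    if rows < cols:\n",
        "        q = q.T\n",
        "    return gain * q\n",
        "\n",
        "\n",
        "w = orthogonal_init(4, 6)\n",
        "# Rows are mutually orthogonal with squared norm gain**2 = 2.\n",
        "assert np.allclose(w @ w.T, 2.0 * np.eye(4))\n",
        "```\n",
        "\n",
        "Convolution kernels are usually handled the same way after flattening all but the output-channel axis."
      ]
    },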
    {
      "cell_type": "code",
      "execution_count": 30,
      "metadata": {
        "id": "FIkJ7z7TiWS4"
      },
      "outputs": [],
      "source": [
        "# import tensorflow as torch\n",
        "# import torch as tf\n",
        "\n",
        "<YOUR CODE: define your model here>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pA2VlyZ32EPt"
      },
      "source": [
        "You will also need to define and use a policy that wraps the model. While the model computes logits for all actions, the policy will sample actions and also compute their log probabilities.  `policy.act` should return a dictionary of all the arrays that are needed to interact with an environment and train the model.\n",
        " Note that actions must be an `np.ndarray` while the other\n",
        "tensors need to have the type determined by your deep learning framework."
      ]
    },
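    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a framework-agnostic illustration of what the policy does with the logits, here is a small NumPy sketch that samples one action per environment and computes the log probability of each sampled action. The function name `sample_actions` is hypothetical, and the Gumbel-max trick is just one way to sample; in a deep learning framework you would more likely use its categorical distribution (for example, `torch.distributions.Categorical` in PyTorch).\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "\n",
        "def sample_actions(logits, rng):\n",
        "    # logits: array of shape (batch, n_actions).\n",
        "    # Log-softmax, shifted by the row maximum for numerical stability.\n",
        "    shifted = logits - logits.max(axis=-1, keepdims=True)\n",
        "    log_probs_all = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))\n",
        "    # Gumbel-max trick: argmax(log_probs + Gumbel noise) is a categorical sample.\n",
        "    gumbel = rng.gumbel(size=logits.shape)\n",
        "    actions = np.argmax(log_probs_all + gumbel, axis=-1)\n",
        "    # Log probability of each sampled action.\n",
        "    log_probs = np.take_along_axis(log_probs_all, actions[:, None], axis=-1)[:, 0]\n",
        "    return actions, log_probs\n",
        "\n",
        "\n",
        "rng = np.random.default_rng(0)\n",
        "logits = np.array([[2.0, 0.5, -1.0], [0.0, 0.0, 0.0]])\n",
        "actions, log_probs = sample_actions(logits, rng)\n",
        "```\n",
        "\n",
        "For training, the log probabilities must stay attached to the computation graph of your framework, so this NumPy version only illustrates the arithmetic."
      ]
    },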
    {
      "cell_type": "code",
      "execution_count": 31,
      "metadata": {
        "id": "dtHP-Fo72EPt"
      },
      "outputs": [],
      "source": [
        "class Policy:\n",
        "    def __init__(self, model):\n",
        "        self.model = model\n",
        "\n",
        "    def act(self, inputs):\n",
        "        # Implement a policy by calling the model, sampling actions and computing their log probs.\n",
        "        # Should return a dict containing keys ['actions', 'logits', 'log_probs', 'values'].\n",
        "\n",
        "        <YOUR CODE>\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2oPCQwsd2EPt"
      },
      "source": [
        "Next will pass the environment and policy to a runner that collects partial trajectories from the environment.\n",
        "The class that does is is already implemented for you."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 32,
      "metadata": {
        "id": "fj-fKr_A2EPt"
      },
      "outputs": [],
      "source": [
        "from runners import EnvRunner"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_9JehIbH2EPt"
      },
      "source": [
        "This runner interacts with the environment for a given number of steps and returns a dictionary containing\n",
        "keys\n",
        "\n",
        "* 'observations'\n",
        "* 'rewards'\n",
        "* 'resets'\n",
        "* 'actions'\n",
        "* all other keys that you defined in `Policy`\n",
        "\n",
        "under each of these keys there is a python `list` of interactions with the environment. This list has length $T$ that is size of partial trajectory. Partial trajectory for given moment `t` is part of `ComputeValueTargets.__call__` input argument `trajectory` from moment `t` to the end (i.e. it's different at each iteration in the algorithm)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iY7FB6s72EPu"
      },
      "source": [
        "To train the part of the model that predicts state values you will need to compute the value targets.\n",
        "Any callable could be passed to `EnvRunner` to be applied to each partial trajectory after it is collected.\n",
        "Thus, we can implement and use `ComputeValueTargets` callable.\n",
        "The formula for the value targets is simple:\n",
        "\n",
        "$$\n",
        "\\hat v(s_t) = \\left( \\sum_{t'=0}^{T - 1} \\gamma^{t'}r_{t+t'} \\right) + \\gamma^T \\hat{v}(s_{t+T}),\n",
        "$$\n",
        "\n",
        "In implementation, however, do not forget to use\n",
        "`trajectory['resets']` flags to check if you need to add the value targets at the next step when\n",
        "computing value targets for the current step. You can access `trajectory['state']['latest_observation']`\n",
        "to get last observations in partial trajectory &mdash; $s_{t+T}$."
      ]
    },
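    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The recursion behind the value targets can be sketched in NumPy as follows. This is an illustration under assumptions, not the required solution: it assumes that rewards and resets are already stacked into arrays of shape `(T, nenvs)`, that `resets[t]` is set when the episode ended at step `t`, and that `last_value` holds the critic's estimate for the latest observation. The function name `n_step_targets` is hypothetical.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "\n",
        "def n_step_targets(rewards, resets, last_value, gamma=0.99):\n",
        "    # rewards, resets: arrays of shape (T, nenvs); last_value: shape (nenvs,).\n",
        "    targets = np.zeros_like(rewards, dtype=np.float64)\n",
        "    next_target = last_value\n",
        "    for t in reversed(range(len(rewards))):\n",
        "        # If the episode ended at step t, do not bootstrap from the next step.\n",
        "        next_target = rewards[t] + gamma * next_target * (1.0 - resets[t])\n",
        "        targets[t] = next_target\n",
        "    return targets\n",
        "\n",
        "\n",
        "rewards = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])\n",
        "resets = np.array([[0.0, 0.0], [0.0, 1.0], [0.0, 0.0]])\n",
        "targets = n_step_targets(rewards, resets, np.array([10.0, 10.0]), gamma=0.5)\n",
        "# targets == [[2.5, 0.5], [3.0, 1.0], [6.0, 6.0]]\n",
        "```\n",
        "\n",
        "In the second environment the episode ends at step 1, so the target there is just the reward, and the bootstrap value from step 2 does not propagate back to steps 1 and 0."
      ]
    },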
    {
      "cell_type": "code",
      "execution_count": 33,
      "metadata": {
        "id": "4CbDi3GZ2EPu"
      },
      "outputs": [],
      "source": [
        "class ComputeValueTargets:\n",
        "    def __init__(self, policy, gamma=0.99):\n",
        "        self.policy = policy\n",
        "        self.gamma = gamma\n",
        "\n",
        "    def __call__(self, trajectory):\n",
        "        \"\"\"Compute value targets for a given partial trajectory.\"\"\"\n",
        "\n",
        "        # This method should modify trajectory inplace by adding\n",
        "        # an item with key 'value_targets' to it.\n",
        "\n",
        "        <YOUR CODE>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9_d9OYyz2EPu"
      },
      "source": [
        "After computing value targets we will transform lists of interactions into tensors\n",
        "with the first dimension `batch_size` which is equal to `env_steps * num_envs`, i.e. you essentially need\n",
        "to flatten the first two dimensions."
      ]
    },
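    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a small illustration of the flattening step, the NumPy sketch below merges the time and environment axes of two arrays. The helper name `merge_time_batch` and the array layout `(T, nenvs, ...)` are assumptions made for the example; in the actual transform the values may be lists of framework tensors that first need to be stacked.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "T, nenvs = 5, 8\n",
        "# Time-major arrays: observations (T, nenvs, 4, 84, 84), rewards (T, nenvs).\n",
        "observations = np.zeros((T, nenvs, 4, 84, 84), dtype=np.float32)\n",
        "rewards = np.arange(T * nenvs, dtype=np.float32).reshape(T, nenvs)\n",
        "\n",
        "\n",
        "def merge_time_batch(x):\n",
        "    # Collapse the leading (time, env) axes into a single batch axis.\n",
        "    return x.reshape(x.shape[0] * x.shape[1], *x.shape[2:])\n",
        "\n",
        "\n",
        "flat_obs = merge_time_batch(observations)  # shape (40, 4, 84, 84)\n",
        "flat_rewards = merge_time_batch(rewards)   # shape (40,)\n",
        "```\n",
        "\n",
        "With this layout, the element at step `t` of environment `i` ends up at flat index `t * nenvs + i`."
      ]
    },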
    {
      "cell_type": "code",
      "execution_count": 34,
      "metadata": {
        "id": "IEnqWlHh2EPu"
      },
      "outputs": [],
      "source": [
        "class MergeTimeBatch:\n",
        "    \"\"\" Merges first two axes typically representing time and env batch. \"\"\"\n",
        "    def __call__(self, trajectory):\n",
        "        # Modify trajectory inplace.\n",
        "\n",
        "        <YOUR CODE>\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 35,
      "metadata": {
        "id": "-2CwwzLl2EPu"
      },
      "outputs": [],
      "source": [
        "model = <YOUR CODE: create your model>\n",
        "policy = Policy(model)\n",
        "runner = EnvRunner(\n",
        "    env=env,\n",
        "    policy=policy,\n",
        "    nsteps=5,\n",
        "    transforms=[\n",
        "        ComputeValueTargets(policy),\n",
        "        MergeTimeBatch(),\n",
        "    ],\n",
        ")\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "IuYy-8Ri2EPu"
      },
      "source": [
        "Now is the time to implement the advantage actor critic algorithm itself. You can look into your lecture,\n",
        "[Mnih et al. 2016](https://arxiv.org/abs/1602.01783) paper, and [lecture](https://www.youtube.com/watch?v=Tol_jw5hWnI&list=PLkFD6_40KJIxJMR-j5A1mkxK26gh_qg37&index=20) by Sergey Levine."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 36,
      "metadata": {
        "id": "hxFLzyRX2EPu"
      },
      "outputs": [],
      "source": [
        "class A2C:\n",
        "    def __init__(self,\n",
        "                 policy,\n",
        "                 optimizer,\n",
        "                 value_loss_coef=0.25,\n",
        "                 entropy_coef=0.01,\n",
        "                 max_grad_norm=0.5):\n",
        "        self.policy = policy\n",
        "        self.optimizer = optimizer\n",
        "        self.value_loss_coef = value_loss_coef\n",
        "        self.entropy_coef = entropy_coef\n",
        "        self.max_grad_norm = max_grad_norm\n",
        "\n",
        "    def policy_loss(self, trajectory):\n",
        "        # You will need to compute advantages here.\n",
        "        <YOUR CODE>\n",
        "\n",
        "    def value_loss(self, trajectory):\n",
        "        <YOUR CODE>\n",
        "\n",
        "    def loss(self, trajectory):\n",
        "        <YOUR CODE>\n",
        "\n",
        "    def step(self, trajectory):\n",
        "        <YOUR CODE>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JIMtFZuG2EPu"
      },
      "source": [
        "Now you can train your model. With reasonable hyperparameters training on a single GTX1080 for 10 million steps across all batched environments (which translates to about 5 hours of wall clock time)\n",
        "it should be possible to achieve *average raw reward over last 100 episodes* (the average is taken over 100 last\n",
        "episodes in each environment in the batch) of about 600. You should plot this quantity with respect to\n",
        "`runner.step_var` &mdash; the number of interactions with all environments. It is highly\n",
        "encouraged to also provide plots of the following quantities (these are useful for debugging as well):\n",
        "\n",
        "* [Coefficient of Determination](https://en.wikipedia.org/wiki/Coefficient_of_determination) between\n",
        "value targets and value predictions\n",
        "* Entropy of the policy $\\pi$\n",
        "* Value loss\n",
        "* Policy loss\n",
        "* Value targets\n",
        "* Value predictions\n",
        "* Gradient norm\n",
        "* Advantages\n",
        "* A2C loss\n",
        "\n",
        "For optimization we suggest you use RMSProp with learning rate starting from 7e-4 and linearly decayed to 0, smoothing constant (alpha in PyTorch and decay in TensorFlow) equal to 0.99 and epsilon equal to 1e-5."
      ]
    },
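    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "For the first diagnostic in the list above, the coefficient of determination between value targets and value predictions can be computed as in the NumPy sketch below. The function name `r_squared` is hypothetical; the formula is the standard one minus the ratio of the residual sum of squares to the total sum of squares.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "\n",
        "def r_squared(targets, predictions):\n",
        "    # 1.0 means perfect predictions; 0.0 is no better than predicting the mean.\n",
        "    # Undefined (division by zero) when all targets are equal.\n",
        "    residual = np.sum((targets - predictions) ** 2)\n",
        "    total = np.sum((targets - np.mean(targets)) ** 2)\n",
        "    return 1.0 - residual / total\n",
        "\n",
        "\n",
        "targets = np.array([1.0, 2.0, 3.0, 4.0])\n",
        "perfect = r_squared(targets, targets)                       # 1.0\n",
        "baseline = r_squared(targets, np.full(4, targets.mean()))   # 0.0\n",
        "```\n",
        "\n",
        "Values close to 1 suggest that the critic tracks its targets well, while values near or below 0 early in training are common and not necessarily a bug."
      ]
    },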
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "#if you use TensorboardSummaries\n",
        "%load_ext tensorboard\n",
        "%tensorboard --logdir logs"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "a2c = <YOUR CODE: create an instance of the algorithm>\n",
        "\n",
        "<YOUR CODE: write a training loop>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZDbgUdMq2EPu"
      },
      "source": [
        "### Target networks?\n",
        "\n",
        "You may recall a technique called \"target networks\" we used a few weeks ago when we trained a DQN agent to play Atari Breakout and wonder why we have not suggested using them here. The answer is that this is more historical than practical.\n",
        "\n",
        "While the \"chasing the target\" problem is still present in actor-critic value estimation and target networks do show up in follow-up papers, the original A3C/A2C papers do not mention them and do not explain this omission.\n",
        "\n",
        "The hypothesis why this may not be a big deal (compared to Q-learning) goes like this. An A3C/A2C agent selects actions based on policy, not an epsilon greedy exploration function, for which the argmax can change drastically due to tiny errors in function approximation. Therefore, errors in the value target caused by target chasing will cause less damage.\n",
        "\n",
        "Also, the actor-critic gradient relies on the advantage function $A(s_t, a_t) = Q(s_t, a_t) - V(s_t)$. Compare this to the $Q$-function $Q(s_t, a_t) = r(s_t, a_t) + \\gamma \\cdot \\mathbb{E}_{s_{t+1} \\mid s_t, a_t} V(s_{t+1})$ used in Q-learning and SARSA: we would expect that any bias in $V$-function approximation will be carried over from $V(s_{t+1})$ to $V(s_t)$ by gradient updates. However, in the formula for the advantage function the two approximations ($Q$-function and $V$-function) come with opposite signs, and thus the errors will cancel out.\n",
        "\n",
        "The last reason may be computational. Authors were concerned to beat existent algorithms in the wall-clock learning time, and any overhead of parameter copying (target network update) counted against this goal."
      ]
    }
  ],
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.9.7"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
